import hail as hl
from hail import ir
from hail.expr import expr_any, expr_array, expr_bool, expr_interval, expr_locus, expr_str
from hail.matrixtable import MatrixTable
from hail.table import Table
from hail.typecheck import dictof, enumeration, func_spec, nullable, oneof, sequenceof, typecheck
from hail.utils.java import Env, info, warning
from hail.utils.misc import new_temp_file, wrap_to_list
from hail.vds.variant_dataset import VariantDataset


def write_variant_datasets(vdss, paths, *, overwrite=False, stage_locally=False, codec_spec=None):
    """Write many `vdss` to their corresponding path in `paths`.

    For each dataset, the reference data is written to ``<path>/reference_data``
    and the variant data to ``<path>/variant_data``. All reference components are
    written with a single multi-write, followed by all variant components.

    Parameters
    ----------
    vdss : iterable of :class:`.VariantDataset`
        Datasets to write.
    paths : iterable of str
        Output paths, one per dataset, in the same order as `vdss`.
    overwrite : bool
        Forwarded to ``ir.MatrixNativeMultiWriter``.
    stage_locally : bool
        Forwarded to ``ir.MatrixNativeMultiWriter``.
    codec_spec : str, optional
        Forwarded to ``ir.MatrixNativeMultiWriter``.

    Raises
    ------
    ValueError
        If `vdss` and `paths` do not have the same number of elements.
    """
    # Materialize both inputs: `vdss` is traversed twice below (reference and
    # variant writes), so a one-shot iterator would leave the second write empty.
    vdss = list(vdss)
    paths = list(paths)
    if len(vdss) != len(paths):
        raise ValueError(
            f"'write_variant_datasets': expected one path per dataset, "
            f"found {len(vdss)} datasets and {len(paths)} paths"
        )

    ref_writer = ir.MatrixNativeMultiWriter(
        [f"{p}/reference_data" for p in paths], overwrite, stage_locally, codec_spec
    )
    var_writer = ir.MatrixNativeMultiWriter([f"{p}/variant_data" for p in paths], overwrite, stage_locally, codec_spec)
    Env.backend().execute(ir.MatrixMultiWrite([vds.reference_data._mir for vds in vdss], ref_writer))
    Env.backend().execute(ir.MatrixMultiWrite([vds.variant_data._mir for vds in vdss], var_writer))


@typecheck(vds=VariantDataset)
def to_dense_mt(vds: 'VariantDataset') -> 'MatrixTable':
    """Creates a single, dense :class:`.MatrixTable` from the split
    :class:`.VariantDataset` representation.

    Rows of the result are the rows of the variant data. At each variant row,
    a sample with no variant entry receives the most recent reference block
    (from that row or an earlier one) that still covers the locus, if any. Such
    reference-derived entries carry the variant entry schema, with fields that
    only exist in the variant data set to missing. If the reference entry has no
    call field, ``GT``/``LGT`` is filled in (from the other call field, or as
    ``0/0``).

    Parameters
    ----------
    vds : :class:`.VariantDataset`
        Dataset in VariantDataset representation.

    Returns
    -------
    :class:`.MatrixTable`
        Dataset in dense MatrixTable representation.

    Raises
    ------
    ValueError
        If the variant data entries have neither a ``GT`` nor an ``LGT`` field.
    """

    # NOTE: There is a strong assumption that ref block LEN does not extend
    # past the end of a contig. That is bad data and we won't correct for it.
    # Garbage in, garbage out.
    ref = vds.reference_data
    ref = ref.annotate_rows(_locus_global_pos=ref.locus.global_position())
    # Store each block's end as a global position so coverage can be tested
    # against any later locus with a single comparison.
    ref = ref.transmute_entries(_END_GLOBAL=ref._locus_global_pos + ref.LEN - 1)

    to_drop = 'alleles', 'rsid', 'ref_allele', '_locus_global_pos'
    ref = ref.drop(*(x for x in to_drop if x in ref.row))
    var = vds.variant_data
    refl = ref.localize_entries('_ref_entries')
    varl = var.localize_entries('_var_entries', '_var_cols')
    varl = varl.annotate(_variant_defined=True)
    joined = varl.key_by('locus').join(refl, how='outer')
    dr = joined.annotate(
        dense_ref=hl.or_missing(
            joined._variant_defined, hl.scan._densify(hl.len(joined._var_cols), joined._ref_entries)
        )
    )
    dr = dr.filter(dr._variant_defined)

    def coalesce_join(ref_entry, var_entry):
        """Pick the variant entry if defined, else the reference entry reshaped to the variant schema."""
        call_field = 'GT' if 'GT' in var_entry else 'LGT'
        if call_field not in var_entry:
            raise ValueError(
                f"'to_dense_mt': variant data entries must contain 'GT' or 'LGT', found {var_entry.dtype}"
            )

        if call_field not in ref_entry:
            ref_call_field = 'GT' if 'GT' in ref_entry else 'LGT'
            if ref_call_field not in ref_entry:
                ref_entry = ref_entry.annotate(**{call_field: hl.call(0, 0)})
            else:
                ref_entry = ref_entry.annotate(**{call_field: ref_entry[ref_call_field]})

        # call_field is now in both ref_entry and var_entry.
        # Field lists follow the variant entry's field order (rather than Python
        # set iteration order, which varies with string hash randomization) so the
        # resulting entry schema is the same on every run.
        ref_fields = set(ref_entry.dtype)
        shared_fields = [f for f in var_entry.dtype if f in ref_fields]
        var_only_fields = [f for f in var_entry.dtype if f not in ref_fields]

        return hl.if_else(
            hl.is_defined(var_entry),
            var_entry.select(*shared_fields, *var_only_fields),
            ref_entry.select(*shared_fields, **{f: hl.missing(var_entry[f].dtype) for f in var_only_fields}),
        )

    dr = dr.annotate(
        _dense=hl.rbind(
            dr._ref_entries,
            lambda refs_at_this_row: hl.enumerate(hl.zip(dr._var_entries, dr.dense_ref)).map(
                lambda tup: coalesce_join(
                    # A block starting at this row wins; otherwise use the densified
                    # earlier block only if it still reaches this locus.
                    hl.coalesce(
                        refs_at_this_row[tup[0]],
                        hl.or_missing(tup[1][1]._END_GLOBAL >= dr.locus.global_position(), tup[1][1]),
                    ),
                    tup[1][0],
                )
            ),
        ),
    )

    dr = dr._key_by_assert_sorted('locus', 'alleles')
    fields_to_drop = ['_var_entries', '_ref_entries', 'dense_ref', '_variant_defined']

    if hl.vds.VariantDataset.ref_block_max_length_field in dr.globals:
        fields_to_drop.append(hl.vds.VariantDataset.ref_block_max_length_field)

    if 'ref_allele' in dr.row:
        fields_to_drop.append('ref_allele')
    dr = dr.drop(*fields_to_drop)
    return dr._unlocalize_entries('_dense', '_var_cols', list(var.col_key))


@typecheck(vds=VariantDataset, ref_allele_function=nullable(func_spec(1, expr_str)))
def to_merged_sparse_mt(vds: 'VariantDataset', *, ref_allele_function=None) -> 'MatrixTable':
    """Creates a single, merged sparse :class:`.MatrixTable` from the split
    :class:`.VariantDataset` representation.

    Reference and variant rows are outer-joined on locus. For each sample, a
    variant entry is used when present, otherwise the reference entry. Every
    merged entry has the union of the reference and variant entry fields. Fields
    absent from the source entry are missing, except that reference entries get
    ``LA = [0]`` (when ``LA`` is in the schema) and ``0/0`` for any ``GT``/``LGT``
    field they lack.

    Parameters
    ----------
    vds : :class:`.VariantDataset`
        Dataset in VariantDataset representation.
    ref_allele_function : function, optional
        One-argument function, called with the joined (row) table, returning a
        string expression used as the sole allele at reference-only loci. If not
        given, a ``ref_allele`` row field is used when present, otherwise the
        reference genome sequence (``locus.sequence_context()``).

    Returns
    -------
    :class:`.MatrixTable`
        Dataset in the merged sparse MatrixTable representation.

    Raises
    ------
    TypeError
        If an entry field appears in both the reference and variant data with
        different types.
    ValueError
        If `ref_allele_function` is not given, there is no ``ref_allele`` row
        field, and the reference genome has no sequence loaded.
    """
    rht = vds.reference_data.localize_entries('_ref_entries', '_ref_cols')
    vht = vds.variant_data.localize_entries('_var_entries', '_var_cols')

    # drop 'alleles' key for join
    vht = vht.key_by('locus')

    # Union of entry fields: reference fields first, then variant-only fields.
    merged_schema = {}
    for e in vds.reference_data.entry:
        merged_schema[e] = vds.reference_data[e].dtype

    for e in vds.variant_data.entry:
        if e in merged_schema:
            if merged_schema[e] != vds.variant_data[e].dtype:
                raise TypeError(f"cannot unify field {e!r}: {merged_schema[e]}, {vds.variant_data[e].dtype}")
        else:
            merged_schema[e] = vds.variant_data[e].dtype

    ht = vht.join(rht, how='outer').drop('_ref_cols')

    def merge_arrays(r_array, v_array):
        """Merge per-sample entry arrays, preferring variant entries over reference entries."""

        def rewrite_ref(r):
            ref_block_selector = {}
            for k, t in merged_schema.items():
                if k == 'LA':
                    ref_block_selector[k] = hl.literal([0])
                elif k in ('LGT', 'GT') and k not in r:
                    ref_block_selector[k] = hl.call(0, 0)
                else:
                    ref_block_selector[k] = r[k] if k in r else hl.missing(t)
            return r.select(**ref_block_selector)

        def rewrite_var(v):
            return v.select(**{k: v[k] if k in v else hl.missing(t) for k, t in merged_schema.items()})

        return (
            hl.case()
            .when(hl.is_missing(r_array), v_array.map(rewrite_var))
            .when(hl.is_missing(v_array), r_array.map(rewrite_ref))
            .default(hl.zip(r_array, v_array).map(lambda t: hl.coalesce(rewrite_var(t[1]), rewrite_ref(t[0]))))
        )

    if ref_allele_function is None:
        rg = ht.locus.dtype.reference_genome
        if 'ref_allele' in ht.row:

            def ref_allele_function(ht):
                return ht.ref_allele

        elif rg.has_sequence():

            def ref_allele_function(ht):
                return ht.locus.sequence_context()

            info("to_merged_sparse_mt: using locus sequence context to fill in reference alleles at monomorphic loci.")
        else:
            raise ValueError(
                "to_merged_sparse_mt: in order to construct a ref allele for reference-only sites, "
                "either pass a function to fill in reference alleles (e.g. ref_allele_function=lambda locus: hl.missing('str'))"
                " or add a sequence file with 'hl.get_reference(RG_NAME).add_sequence(FASTA_PATH)'."
            )
    ht = ht.select(
        # Reference-only rows have no alleles; synthesize a single-allele array.
        alleles=hl.coalesce(ht['alleles'], hl.array([ref_allele_function(ht)])),
        # handle cases where vmt is not keyed by alleles
        **{k: ht[k] for k in vds.variant_data.row_value if k != 'alleles'},
        _entries=merge_arrays(ht['_ref_entries'], ht['_var_entries']),
    )
    ht = ht._key_by_assert_sorted('locus', 'alleles')
    return ht._unlocalize_entries('_entries', '_var_cols', list(vds.variant_data.col_key))


@typecheck(vds=VariantDataset, samples=oneof(Table, expr_array(expr_str)), keep=bool, remove_dead_alleles=bool)
def filter_samples(
    vds: 'VariantDataset', samples, *, keep: bool = True, remove_dead_alleles: bool = False
) -> 'VariantDataset':
    """Filter samples in a :class:`.VariantDataset`.

    After filtering columns, rows with no remaining entries are removed from
    both the reference and the variant data.

    Parameters
    ----------
    vds : :class:`.VariantDataset`
        Dataset in VariantDataset representation.
    samples : :class:`.Table` or list of str
        Samples to keep or remove. A table must be keyed by a single field of
        type ``str``.
    keep : :obj:`bool`
        Whether to keep (default), or filter out the samples in `samples`.
    remove_dead_alleles : :obj:`bool`
        If true, remove alleles observed in no samples. Alleles with AC == 0 will be
        removed, and LA values recalculated.

    Returns
    -------
    :class:`.VariantDataset`

    Raises
    ------
    TypeError
        If `samples` is a table not keyed by a single string field.
    """
    if not isinstance(samples, hl.Table):
        samples = hl.Table.parallelize(samples.map(lambda s: hl.struct(s=s)), key='s')
    if [samples[x].dtype for x in samples.key] != [hl.tstr]:
        raise TypeError(f'invalid key: {samples.key.dtype}')
    samples_to_keep = samples.aggregate(hl.agg.collect_as_set(samples.key[0]), _localize=False)._persist()
    reference_data = vds.reference_data.filter_cols(samples_to_keep.contains(vds.reference_data.col_key[0]), keep=keep)
    reference_data = reference_data.filter_rows(hl.agg.count() > 0)
    variant_data = vds.variant_data.filter_cols(samples_to_keep.contains(vds.variant_data.col_key[0]), keep=keep)

    if remove_dead_alleles:
        vd = variant_data
        # Count how often each local-allele index appears among remaining entries.
        vd = vd.annotate_rows(__allele_counts=hl.agg.explode(lambda x: hl.agg.counter(x), vd.LA), __n=hl.agg.count())
        vd = vd.filter_rows(vd.__n > 0)
        vd = vd.drop('__n')

        # Map old allele index -> new allele index for kept alleles (ref always kept).
        vd = vd.annotate_rows(
            __kept_indices=hl.dict(
                hl.enumerate(
                    hl.range(hl.len(vd.alleles)).filter(lambda idx: (idx == 0) | (vd.__allele_counts.get(idx, 0) > 0)),
                    index_first=False,
                )
            )
        )

        # Dense lookup array; -1 marks removed alleles.
        vd = vd.annotate_rows(
            __old_to_new_LA=hl.range(hl.len(vd.alleles)).map(lambda idx: vd.__kept_indices.get(idx, -1))
        )

        def new_la_index(old_idx):
            raw_idx = vd.__old_to_new_LA[old_idx]
            return (
                hl.case()
                .when(raw_idx >= 0, raw_idx)
                .or_error("'filter_samples': unexpected local allele: old index=" + hl.str(old_idx))
            )

        vd = vd.annotate_entries(LA=vd.LA.map(lambda la: new_la_index(la)))
        # 'alleles' is a key field, so re-key to rewrite it, then restore the key.
        vd = vd.key_rows_by('locus')
        vd = vd.annotate_rows(alleles=vd.__kept_indices.keys().map(lambda i: vd.alleles[i]))
        vd = vd._key_rows_by_assert_sorted('locus', 'alleles')
        vd = vd.drop('__allele_counts', '__kept_indices', '__old_to_new_LA')
        return VariantDataset(reference_data, vd)

    variant_data = variant_data.filter_rows(hl.agg.count() > 0)
    return VariantDataset(reference_data, variant_data)


@typecheck(mt=MatrixTable, normalization_contig=str)
def impute_sex_chr_ploidy_from_interval_coverage(
    mt: 'MatrixTable',
    normalization_contig: str,
) -> 'Table':
    """Impute sex chromosome ploidy from a precomputed interval coverage MatrixTable.

    The input MatrixTable must have the following row fields:

     - ``interval`` (*interval*): Genomic interval of interest.
     - ``interval_size`` (*int32*): Size of interval, in bases.

    And the following entry fields:

     -  ``sum_dp`` (*int64*): Sum of depth values by base across the interval.

    Returns a :class:`.Table` with sample ID keys, with the following fields:

     -  ``autosomal_mean_dp`` (*float64*): Mean depth on calling intervals on normalization contig.
     -  ``x_mean_dp`` (*float64*): Mean depth on calling intervals on X chromosome.
     -  ``x_ploidy`` (*float64*): Estimated ploidy on X chromosome. Equal to ``2 * x_mean_dp / autosomal_mean_dp``.
     -  ``y_mean_dp`` (*float64*): Mean depth on calling intervals on Y chromosome.
     -  ``y_ploidy`` (*float64*): Estimated ploidy on Y chromosome. Equal to ``2 * y_mean_dp / autosomal_mean_dp``.

    A contig with no intervals in `mt` contributes a mean depth of ``0.0``.

    Parameters
    ----------
    mt : :class:`.MatrixTable`
        Interval-by-sample MatrixTable with sum of depth values across the interval.
    normalization_contig : str
        Autosomal contig for depth comparison.

    Returns
    -------
    :class:`.Table`

    Raises
    ------
    NotImplementedError
        If the reference genome does not have exactly one X contig and exactly
        one Y contig.
    """

    rg = mt.interval.start.dtype.reference_genome

    if len(rg.x_contigs) != 1:
        raise NotImplementedError(
            f"reference genome {rg.name!r} has multiple X contigs, this is not supported in 'impute_sex_chr_ploidy_from_interval_coverage'"
        )
    chr_x = rg.x_contigs[0]
    if len(rg.y_contigs) != 1:
        raise NotImplementedError(
            f"reference genome {rg.name!r} has multiple Y contigs, this is not supported in 'impute_sex_chr_ploidy_from_interval_coverage'"
        )
    chr_y = rg.y_contigs[0]

    mt = mt.annotate_rows(contig=mt.interval.start.contig)
    # Per sample: total depth / total bases, grouped by contig.
    mt = mt.annotate_cols(__mean_dp=hl.agg.group_by(mt.contig, hl.agg.sum(mt.sum_dp) / hl.agg.sum(mt.interval_size)))

    mean_dp_dict = mt.__mean_dp
    auto_dp = mean_dp_dict.get(normalization_contig, 0.0)
    x_dp = mean_dp_dict.get(chr_x, 0.0)
    y_dp = mean_dp_dict.get(chr_y, 0.0)
    per_sample = mt.transmute_cols(
        autosomal_mean_dp=auto_dp,
        x_mean_dp=x_dp,
        x_ploidy=2 * x_dp / auto_dp,
        y_mean_dp=y_dp,
        y_ploidy=2 * y_dp / auto_dp,
    )
    info("'impute_sex_chromosome_ploidy': computing and checkpointing coverage and karyotype metrics")
    return per_sample.cols().checkpoint(new_temp_file('impute_sex_karyotype', extension='ht'))


@typecheck(
    vds=VariantDataset,
    calling_intervals=oneof(Table, expr_array(expr_interval(expr_locus()))),
    normalization_contig=str,
    use_variant_dataset=bool,
)
def impute_sex_chromosome_ploidy(
    vds: VariantDataset, calling_intervals, normalization_contig: str, use_variant_dataset: bool = False
) -> Table:
    """Impute sex chromosome ploidy from depth of reference or variant data within calling intervals.

    Returns a :class:`.Table` with sample ID keys, with the following fields:

     -  ``autosomal_mean_dp`` (*float64*): Mean depth on calling intervals on normalization contig.
     -  ``x_mean_dp`` (*float64*): Mean depth on calling intervals on X chromosome.
     -  ``x_ploidy`` (*float64*): Estimated ploidy on X chromosome. Equal to ``2 * x_mean_dp / autosomal_mean_dp``.
     -  ``y_mean_dp`` (*float64*): Mean depth on calling intervals on Y chromosome.
     -  ``y_ploidy`` (*float64*): Estimated ploidy on Y chromosome. Equal to ``2 * y_mean_dp / autosomal_mean_dp``.

    Calling intervals are split at pseudoautosomal (PAR) boundaries and any piece
    overlapping a PAR is discarded before coverage is computed.

    Parameters
    ----------
    vds : :class:`.VariantDataset`
        Dataset.
    calling_intervals : :class:`.Table` or :class:`.ArrayExpression`
        Calling intervals with consistent read coverage (for exomes, trim the capture intervals).
    normalization_contig : str
        Autosomal contig for depth comparison.
    use_variant_dataset : bool
        Whether to use depth of variant data within calling intervals instead of reference data. Default will use reference data.

    Returns
    -------
    :class:`.Table`

    Raises
    ------
    ValueError
        If `calling_intervals` is a table without a single ``interval<locus>`` key
        matching `vds`, or if any calling interval spans two chromosomes.
    NotImplementedError
        If the reference genome does not have exactly one X and one Y contig.
    """

    if not isinstance(calling_intervals, Table):
        calling_intervals = hl.Table.parallelize(
            hl.map(lambda i: hl.struct(interval=i), calling_intervals),
            schema=hl.tstruct(interval=calling_intervals.dtype.element_type),
            key='interval',
        )
    else:
        key_dtype = calling_intervals.key.dtype
        if (
            len(key_dtype) != 1
            or not isinstance(calling_intervals.key[0].dtype, hl.tinterval)
            or calling_intervals.key[0].dtype.point_type != vds.reference_data.locus.dtype
        ):
            raise ValueError(
                f"'impute_sex_chromosome_ploidy': expect calling_intervals to be list of intervals or"
                f" table with single key of type interval<locus>, found table with key: {key_dtype}"
            )

    rg = vds.reference_data.locus.dtype.reference_genome

    # Fail fast on unsupported reference genomes, before the segmenting,
    # checkpointing and aggregation work below.
    if len(rg.x_contigs) != 1:
        raise NotImplementedError(
            f"reference genome {rg.name!r} has multiple X contigs, this is not supported in 'impute_sex_chromosome_ploidy'"
        )
    if len(rg.y_contigs) != 1:
        raise NotImplementedError(
            f"reference genome {rg.name!r} has multiple Y contigs, this is not supported in 'impute_sex_chromosome_ploidy'"
        )

    par_boundaries = []
    for par_interval in rg.par:
        par_boundaries.append(par_interval.start)
        par_boundaries.append(par_interval.end)

    # segment on PAR interval boundaries
    calling_intervals = hl.segment_intervals(calling_intervals, par_boundaries)

    # remove intervals overlapping PAR
    calling_intervals = calling_intervals.filter(
        hl.all(lambda x: ~x.overlaps(calling_intervals.interval), hl.literal(rg.par))
    )

    # checkpoint for efficient multiple downstream usages
    info("'impute_sex_chromosome_ploidy': checkpointing calling intervals")
    calling_intervals = calling_intervals.checkpoint(new_temp_file(extension='ht'))

    interval = calling_intervals.key[0]
    (any_bad_intervals, chrs_represented) = calling_intervals.aggregate((
        hl.agg.any(interval.start.contig != interval.end.contig),
        hl.agg.collect_as_set(interval.start.contig),
    ))
    if any_bad_intervals:
        raise ValueError(
            "'impute_sex_chromosome_ploidy' does not support calling intervals that span chromosome boundaries"
        )

    # Restrict the dataset to contigs that have at least one calling interval.
    kept_contig_filter = hl.array(chrs_represented).map(lambda x: hl.parse_locus_interval(x, reference_genome=rg))
    vds = VariantDataset(
        hl.filter_intervals(vds.reference_data, kept_contig_filter),
        hl.filter_intervals(vds.variant_data, kept_contig_filter),
    )

    if use_variant_dataset:
        mt = vds.variant_data
        calling_intervals = calling_intervals.annotate(interval_dup=interval)
        mt = mt.annotate_rows(interval=calling_intervals[mt.locus].interval_dup)
        mt = mt.filter_rows(hl.is_defined(mt.interval))
        # Each variant entry with a defined DP counts as one covered base.
        coverage = mt.select_entries(sum_dp=mt.DP, interval_size=hl.is_defined(mt.DP))
    else:
        coverage = interval_coverage(vds, calling_intervals, gq_thresholds=()).drop('gq_thresholds')

    return impute_sex_chr_ploidy_from_interval_coverage(coverage, normalization_contig)


@typecheck(vds=VariantDataset, variants_table=Table, keep=bool)
def filter_variants(vds: 'VariantDataset', variants_table: 'Table', *, keep: bool = True) -> 'VariantDataset':
    """Filter variants in a :class:`.VariantDataset`, without removing reference
    data.

    Only the variant data is filtered, by a semi-join (``keep=True``) or an
    anti-join (``keep=False``) of its rows against `variants_table`; the
    reference data is passed through unchanged.

    Parameters
    ----------
    vds : :class:`.VariantDataset`
        Dataset in VariantDataset representation.
    variants_table : :class:`.Table`
        Variants to filter on.
    keep: :obj:`bool`
        Whether to keep (default), or filter out the variants from `variants_table`.

    Returns
    -------
    :class:`.VariantDataset`.
    """
    source = vds.variant_data
    row_join = source.semi_join_rows if keep else source.anti_join_rows
    return VariantDataset(vds.reference_data, row_join(variants_table))


@typecheck(
    vds=VariantDataset,
    intervals=oneof(Table, expr_array(expr_interval(expr_any))),
    keep=bool,
    mode=enumeration('variants_only', 'split_at_boundaries', 'unchecked_filter_both'),
)
def _parameterized_filter_intervals(vds: 'VariantDataset', intervals, keep: bool, mode: str) -> 'VariantDataset':
    """Shared implementation of interval filtering for a :class:`.VariantDataset`.

    Parameters
    ----------
    vds : :class:`.VariantDataset`
        Dataset to filter.
    intervals : :class:`.Table` or :class:`.ArrayExpression` of intervals
        Intervals to filter on. A table must have a single key of type
        ``interval<locus<RG>>`` matching the dataset's reference genome.
    keep : bool
        Keep (true) or remove (false) rows in `intervals`.
    mode : str
        - ``'unchecked_filter_both'``: apply the same interval filter to the
          reference and variant data.
        - ``'variants_only'``: filter the variant data. When `keep` is true and
          the max reference block length global is present, the reference data is
          filtered with each interval's start moved back by ``max_len - 1`` bases,
          so blocks beginning before an interval but reaching into it are kept.
        - ``'split_at_boundaries'``: as ``'variants_only'``, then segment the
          reference blocks at interval boundaries. Requires ``keep=True``.

    Returns
    -------
    :class:`.VariantDataset`

    Raises
    ------
    ValueError
        If an intervals table has the wrong key, or ``'split_at_boundaries'`` is
        used with ``keep=False``.
    """
    intervals_table = None
    if isinstance(intervals, Table):
        expected = hl.tinterval(hl.tlocus(vds.reference_genome))
        if len(intervals.key) != 1 or intervals.key[0].dtype != expected:
            raise ValueError(
                f"'filter_intervals': expect a table with a single key of type {expected}; "
                f"found {list(intervals.key.dtype.values())}"
            )
        intervals_table = intervals
        intervals = hl.literal(intervals.aggregate(hl.agg.collect(intervals.key[0]), _localize=False))

    if mode == 'unchecked_filter_both':
        return VariantDataset(
            hl.filter_intervals(vds.reference_data, intervals, keep),
            hl.filter_intervals(vds.variant_data, intervals, keep),
        )

    reference_data = vds.reference_data
    if keep:
        rbml = hl.vds.VariantDataset.ref_block_max_length_field
        if rbml in vds.reference_data.globals:
            max_len = hl.eval(vds.reference_data.index_globals()[rbml])
            ref_intervals = intervals.map(
                lambda interval: hl.interval(
                    interval.start - (max_len - 1), interval.end, interval.includes_start, interval.includes_end
                )
            )
            reference_data = hl.filter_intervals(reference_data, ref_intervals, keep)
        else:
            warning(
                "'hl.vds.filter_intervals': filtering intervals without a known max reference block length"
                "\n  (computed by `hl.vds.store_ref_block_max_length` or 'hl.vds.truncate_reference_blocks')"
                "\n  requires a full pass over the reference data (expensive!)"
            )

    if mode == 'variants_only':
        variant_data = hl.filter_intervals(vds.variant_data, intervals, keep)
        return VariantDataset(reference_data, variant_data)
    if mode == 'split_at_boundaries':
        if not keep:
            raise ValueError("filter_intervals mode 'split_at_boundaries' not implemented for keep=False")
        # Explicit None check: do not rely on the truthiness of a Table object.
        if intervals_table is not None:
            par_intervals = intervals_table
        else:
            par_intervals = hl.Table.parallelize(
                intervals.map(lambda x: hl.struct(interval=x)),
                schema=hl.tstruct(interval=intervals.dtype.element_type),
                key='interval',
            )
        ref = segment_reference_blocks(reference_data, par_intervals).drop(
            'interval_end', next(iter(par_intervals.key))
        )
        return VariantDataset(ref, hl.filter_intervals(vds.variant_data, intervals, keep))


@typecheck(
    vds=VariantDataset,
    keep=nullable(oneof(str, sequenceof(str))),
    remove=nullable(oneof(str, sequenceof(str))),
    keep_autosomes=bool,
)
def filter_chromosomes(vds: 'VariantDataset', *, keep=None, remove=None, keep_autosomes=False) -> 'VariantDataset':
    """Filter chromosomes of a :class:`.VariantDataset` in several possible modes.

    Notes
    -----
    There are three modes for :func:`filter_chromosomes`, based on which argument is passed
    to the function. Exactly one of the below arguments must be passed by keyword.

     - ``keep``: This argument expects a single chromosome identifier or a list of chromosome
       identifiers, and the function returns a :class:`.VariantDataset` with only those
       chromosomes.
     - ``remove``: This argument expects a single chromosome identifier or a list of chromosome
       identifiers, and the function returns a :class:`.VariantDataset` with those chromosomes
       removed.
     - ``keep_autosomes``: This argument expects the value ``True``, and returns a dataset without
       sex and mitochondrial chromosomes.

    In every mode the selection is turned into a list of whole-contig intervals,
    which is applied to both the reference and the variant data.

    Parameters
    ----------
    vds : :class:`.VariantDataset`
        Dataset.
    keep
        Keep a specified list of contigs.
    remove
        Remove a specified list of contigs.
    keep_autosomes
        If true, keep only autosomal chromosomes.

    Returns
    -------
    :class:`.VariantDataset`.
    """

    modes_selected = sum([keep is not None, remove is not None, keep_autosomes])
    if not modes_selected:
        raise ValueError("filter_chromosomes: expect one of 'keep', 'remove', or 'keep_autosomes' arguments")
    if modes_selected != 1:
        raise ValueError(
            "filter_chromosomes: expect ONLY one of 'keep', 'remove', or 'keep_autosomes' arguments"
            "\n  In order use 'keep_autosomes' with 'keep' or 'remove', call the function twice"
        )

    genome = vds.reference_genome

    if keep is not None:
        contigs = list(wrap_to_list(keep))
    else:
        if remove is not None:
            excluded = set(wrap_to_list(remove))
        else:
            # keep_autosomes: drop sex and mitochondrial contigs.
            excluded = set(genome.x_contigs + genome.y_contigs + genome.mt_contigs)
        contigs = [name for name in genome.contigs if name not in excluded]

    contig_intervals = hl.literal(contigs, hl.tarray(hl.tstr)).map(
        lambda name: hl.parse_locus_interval(name, reference_genome=genome)
    )
    return _parameterized_filter_intervals(vds, intervals=contig_intervals, keep=True, mode='unchecked_filter_both')


@typecheck(
    vds=VariantDataset,
    intervals=oneof(Table, expr_array(expr_interval(expr_any))),
    split_reference_blocks=bool,
    keep=bool,
)
def filter_intervals(
    vds: 'VariantDataset', intervals, *, split_reference_blocks: bool = False, keep: bool = True
) -> 'VariantDataset':
    """Filter intervals in a :class:`.VariantDataset`.

    Parameters
    ----------
    vds : :class:`.VariantDataset`
        Dataset in VariantDataset representation.
    intervals : :class:`.Table` or :class:`.ArrayExpression` of type :class:`.tinterval`
        Intervals to filter on. A table must be keyed by a single interval field.
    split_reference_blocks: :obj:`bool`
        If true, remove reference data outside the given intervals by segmenting reference
        blocks at interval boundaries. Results in a smaller result, but this filter mode
        is more computationally expensive to evaluate.
    keep : :obj:`bool`
        Whether to keep (default), or filter out rows that fall within any
        interval in `intervals`.

    Returns
    -------
    :class:`.VariantDataset`

    Raises
    ------
    ValueError
        If `split_reference_blocks` is combined with ``keep=False``.
    """
    if split_reference_blocks and not keep:
        raise ValueError("'filter_intervals': cannot use 'split_reference_blocks' with keep=False")
    return _parameterized_filter_intervals(
        vds, intervals, keep=keep, mode='split_at_boundaries' if split_reference_blocks else 'variants_only'
    )


@typecheck(vds=VariantDataset, filter_changed_loci=bool)
def split_multi(vds: 'VariantDataset', *, filter_changed_loci: bool = False) -> 'VariantDataset':
    """Split the multiallelic variants in a :class:`.VariantDataset`.

    The variant data is split with :func:`.experimental.sparse_split_multi`.
    In the reference data, an ``LGT`` entry field is dropped if ``GT`` already
    exists, and otherwise renamed to ``GT``.

    Parameters
    ----------
    vds : :class:`.VariantDataset`
        Dataset in VariantDataset representation.
    filter_changed_loci : :obj:`bool`
        If any REF/ALT pair changes locus under :func:`.min_rep`, filter that
        variant instead of throwing an error.

    Returns
    -------
    :class:`.VariantDataset`
    """
    split_variants = hl.experimental.sparse_split_multi(vds.variant_data, filter_changed_loci=filter_changed_loci)
    ref = vds.reference_data
    entry_fields = ref.entry

    if 'LGT' in entry_fields:
        # After this, the reference data carries only a GT call field.
        ref = ref.drop('LGT') if 'GT' in entry_fields else ref.transmute_entries(GT=ref.LGT)

    return VariantDataset(reference_data=ref, variant_data=split_variants)


@typecheck(ref=MatrixTable, intervals=Table)
def segment_reference_blocks(ref: 'MatrixTable', intervals: 'Table') -> 'MatrixTable':
    """Returns a matrix table of reference blocks segmented according to intervals.

    Loci outside the given intervals are discarded. Reference blocks that start before
    but span an interval will appear at the interval start locus. Every block's
    ``END`` is clipped to the last position of its interval. The result has row
    fields ``interval_end`` and the interval key field of `intervals`.

    Note
    ----
        Assumes disjoint intervals which do not span contigs.

        Requires start-inclusive intervals.

    Parameters
    ----------
    ref : :class:`.MatrixTable`
        MatrixTable of reference blocks.
    intervals : :class:`.Table`
        Table of intervals at which to segment reference blocks.

    Returns
    -------
    :class:`.MatrixTable`

    Raises
    ------
    ValueError
        If `intervals` is not keyed by intervals of the matching locus type, if
        any interval is not start-inclusive, or if any interval spans contigs.
    """
    interval_field = next(iter(intervals.key))
    if intervals[interval_field].dtype != hl.tinterval(ref.locus.dtype):
        raise ValueError(
            f"expect intervals to be keyed by intervals of loci matching the VariantDataset:"
            f" found {intervals[interval_field].dtype} / {ref.locus.dtype}"
        )
    intervals = intervals.select(_interval_dup=intervals[interval_field])

    # Validate both requirements in one pass, but report each failure accurately.
    interval_expr = intervals[interval_field]
    all_start_inclusive, all_single_contig = intervals.aggregate((
        hl.agg.all(interval_expr.includes_start),
        hl.agg.all(interval_expr.start.contig == interval_expr.end.contig),
    ))
    if not all_start_inclusive:
        raise ValueError("expect intervals to be start-inclusive")
    if not all_single_contig:
        raise ValueError("expect intervals to not span contigs")

    starts = intervals.key_by(_start_locus=intervals[interval_field].start)
    starts = starts.annotate(_include_locus=True)
    refl = ref.localize_entries('_ref_entries', '_ref_cols')
    joined = refl.join(starts, how='outer')
    rg = ref.locus.dtype.reference_genome
    contig_idx_map = hl.literal({contig: idx for idx, contig in enumerate(rg.contigs)}, 'dict<str, int32>')
    # Tag entries with their contig index so densified blocks never leak across contigs.
    joined = joined.annotate(__contig_idx=contig_idx_map[joined.locus.contig])
    joined = joined.annotate(
        _ref_entries=joined._ref_entries.map(lambda e: e.annotate(__contig_idx=joined.__contig_idx))
    )
    dense = joined.annotate(
        dense_ref=hl.or_missing(
            joined._include_locus,
            hl.rbind(
                joined.locus.position,
                lambda pos: hl.enumerate(hl.scan._densify(hl.len(joined._ref_cols), joined._ref_entries)).map(
                    lambda idx_and_e: hl.rbind(
                        idx_and_e[0],
                        idx_and_e[1],
                        # Prefer a block starting at this row; otherwise carry forward an
                        # earlier block on the same contig that still covers this position.
                        lambda idx, e: hl.coalesce(
                            joined._ref_entries[idx],
                            hl.or_missing((e.__contig_idx == joined.__contig_idx) & (e.END >= pos), e),
                        ),
                    ).drop('__contig_idx')
                ),
            ),
        )
    )
    dense = dense.filter(dense._include_locus).drop('_interval_dup', '_include_locus', '__contig_idx')

    # at this point, 'dense' is a table with dense rows of reference blocks, keyed by locus

    refl_filtered = refl.annotate(**{interval_field: intervals[refl.locus]._interval_dup})

    # remove rows that are not contained in an interval, and rows that are the start of an
    # interval (interval starts come from the 'dense' table)
    refl_filtered = refl_filtered.filter(
        hl.is_defined(refl_filtered[interval_field]) & (refl_filtered.locus != refl_filtered[interval_field].start)
    )

    # union dense interval starts with filtered table
    refl_filtered = refl_filtered.union(dense.transmute(_ref_entries=dense.dense_ref))

    # rewrite reference blocks to end at the first of (interval end, reference block end)
    refl_filtered = refl_filtered.annotate(
        interval_end=refl_filtered[interval_field].end.position - ~refl_filtered[interval_field].includes_end
    )
    refl_filtered = refl_filtered.annotate(
        _ref_entries=refl_filtered._ref_entries.map(
            lambda entry: entry.annotate(END=hl.min(entry.END, refl_filtered.interval_end))
        )
    )

    return refl_filtered._unlocalize_entries('_ref_entries', '_ref_cols', list(ref.col_key))


@typecheck(
    vds=VariantDataset,
    intervals=Table,
    gq_thresholds=sequenceof(int),
    dp_thresholds=sequenceof(int),
    dp_field=nullable(str),
)
def interval_coverage(
    vds: VariantDataset,
    intervals: Table,
    gq_thresholds=(0, 10, 20),
    dp_thresholds=(0, 1, 10, 20, 30),
    dp_field=None,
) -> 'MatrixTable':
    """Compute statistics about base coverage by interval.

    Returns a :class:`.MatrixTable` with interval row keys and sample column keys.

    Contains the following row fields:
     - ``interval`` (*interval*): Genomic interval of interest.
     - ``interval_size`` (*int32*): Size of interval, in bases.

    Contains the following global fields:
     - ``gq_thresholds`` (*tuple of int32*): The GQ thresholds used.

    Computes the following entry fields:

     -  ``bases_over_gq_threshold`` (*tuple of int64*): Number of bases in the interval
        at or above (``>=``) each GQ threshold.
     -  ``fraction_over_gq_threshold`` (*tuple of float64*): Fraction of interval (in bases)
        at or above each GQ threshold. Computed by dividing each member of
        *bases_over_gq_threshold* by *interval_size*.
     -  ``bases_over_dp_threshold`` (*tuple of int64*): Number of bases in the interval
        at or above (``>=``) each DP threshold.
     -  ``fraction_over_dp_threshold`` (*tuple of float64*): Fraction of interval (in bases)
        at or above each DP threshold. Computed by dividing each member of
        *bases_over_dp_threshold* by *interval_size*.
     -  ``sum_dp`` (*int64*): Sum of depth values by base across the interval.
     -  ``mean_dp`` (*float64*): Mean depth of bases across the interval. Computed by dividing
        *sum_dp* by *interval_size*.

    If the `dp_field` parameter is specified, that field is used for depth. Otherwise
    ``DP`` is used if present. If no ``DP`` field is present, the ``MIN_DP`` field is
    used. If no depth field is available, no depth statistics will be calculated.

    Note
    ----
    The metrics computed by this method are computed **only from reference blocks**. Most
    variant callers produce data where non-reference calls interrupt reference blocks, and
    so the metrics computed here are slight underestimates of the true values (which would
    include the quality/depth of non-reference calls). This is likely a negligible difference,
    but is something to be aware of, especially as it interacts with samples of
    ancestral backgrounds with more or fewer non-reference calls.

    Parameters
    ----------
    vds : :class:`.VariantDataset`
    intervals : :class:`.Table`
        Table of intervals. Must be start-inclusive, and cannot span contigs.
    gq_thresholds : tuple of int
        GQ thresholds.
    dp_thresholds : tuple of int
        DP thresholds. Only used when a depth field is available.
    dp_field : str, optional
        Field for depth calculation. If given, it must be an entry field of the reference
        data and takes priority. Otherwise uses DP or MIN_DP (with priority for DP if present).

    Returns
    -------
    :class:`.MatrixTable`
        Interval-by-sample matrix

    Raises
    ------
    ValueError
        If `dp_field` is given but is not an entry field of the reference data.
    """
    ref = vds.reference_data
    split = segment_reference_blocks(ref, intervals)
    intervals = intervals.annotate(interval_dup=intervals.key[0])

    # An explicitly requested depth field wins; DP / MIN_DP are only fallbacks.
    if dp_field is not None:
        if dp_field not in ref.entry:
            raise ValueError(
                f"interval_coverage: dp_field {dp_field!r} is not an entry field of the reference data"
            )
        dp_field_to_use = dp_field
    elif 'DP' in ref.entry:
        dp_field_to_use = 'DP'
    elif 'MIN_DP' in ref.entry:
        dp_field_to_use = 'MIN_DP'
    else:
        dp_field_to_use = None

    # Number of bases covered by each (interval-segmented) reference block; END is inclusive.
    ref_block_length = split.END - split.locus.position + 1
    if dp_field_to_use is not None:
        dp = split[dp_field_to_use]
        dp_field_dict = {
            'sum_dp': hl.agg.sum(ref_block_length * dp),
            'bases_over_dp_threshold': tuple(
                hl.agg.filter(dp >= dp_threshold, hl.agg.sum(ref_block_length)) for dp_threshold in dp_thresholds
            ),
        }
    else:
        dp_field_dict = dict()

    per_interval = split.group_rows_by(interval=intervals[split.row_key[0]].interval_dup).aggregate(
        bases_over_gq_threshold=tuple(
            hl.agg.filter(split.GQ >= gq_threshold, hl.agg.sum(ref_block_length)) for gq_threshold in gq_thresholds
        ),
        **dp_field_dict,
    )

    # Interval size in bases, accounting for open/closed endpoints.
    interval = per_interval.interval
    interval_size = (
        interval.end.position + interval.includes_end - interval.start.position - 1 + interval.includes_start
    )
    per_interval = per_interval.annotate_rows(interval_size=interval_size)

    dp_mod_dict = {}
    if dp_field_to_use is not None:
        dp_mod_dict['fraction_over_dp_threshold'] = tuple(
            hl.float(x) / per_interval.interval_size for x in per_interval.bases_over_dp_threshold
        )
        dp_mod_dict['mean_dp'] = per_interval.sum_dp / per_interval.interval_size

    per_interval = per_interval.annotate_entries(
        fraction_over_gq_threshold=tuple(
            hl.float(x) / per_interval.interval_size for x in per_interval.bases_over_gq_threshold
        ),
        **dp_mod_dict,
    )

    per_interval = per_interval.annotate_globals(gq_thresholds=hl.tuple(gq_thresholds))

    return per_interval


@typecheck(
    ds=oneof(MatrixTable, VariantDataset),
    max_ref_block_base_pairs=nullable(int),
    ref_block_winsorize_fraction=nullable(float),
)
def truncate_reference_blocks(ds, *, max_ref_block_base_pairs=None, ref_block_winsorize_fraction=None):
    """Cap reference blocks at a maximum length in order to permit faster interval filtering.

    Examples
    --------
    Truncate reference blocks to 5 kilobases:

    >>> vds2 = hl.vds.truncate_reference_blocks(vds, max_ref_block_base_pairs=5000) # doctest: +SKIP

    Truncate the longest 1% of reference blocks to the length of the 99th percentile block:

    >>> vds2 = hl.vds.truncate_reference_blocks(vds, ref_block_winsorize_fraction=0.01) # doctest: +SKIP

    Notes
    -----
    After this function has been run, the reference blocks have a known maximum length `ref_block_max_length`,
    stored in the global fields, which permits :func:`.vds.filter_intervals` to filter to intervals of the reference
    data by reading `ref_block_max_length` bases ahead of each interval. This allows narrow interval queries
    to run in roughly O(data kept) work rather than O(all reference data) work.

    It is also possible to patch an existing VDS to store the max reference block length with :func:`.vds.store_ref_block_max_length`.

    See Also
    --------
    :func:`.vds.store_ref_block_max_length`.

    Parameters
    ----------
    ds : :class:`.VariantDataset` or :class:`.MatrixTable`
        Variant dataset or reference block matrix table.
    max_ref_block_base_pairs
        Maximum size of reference blocks, in base pairs. Must be greater than zero.
    ref_block_winsorize_fraction
        Fraction of reference block length distribution to truncate / winsorize.
        Must be strictly between 0 and 1.

    Exactly one of `max_ref_block_base_pairs` and `ref_block_winsorize_fraction` must be given.

    Returns
    -------
    :class:`.VariantDataset` or :class:`.MatrixTable`

    Raises
    ------
    ValueError
        If both or neither of `max_ref_block_base_pairs` / `ref_block_winsorize_fraction` are
        given, if `ref_block_winsorize_fraction` is not strictly between 0 and 1, or if the
        maximum block length is not greater than zero.
    """
    if isinstance(ds, VariantDataset):
        rd = ds.reference_data
    else:
        rd = ds

    # Drop any previously stored maximum; it is re-annotated below.
    fd_name = hl.vds.VariantDataset.ref_block_max_length_field
    if fd_name in rd.globals:
        rd = rd.drop(fd_name)

    # Both given, or neither given, is an error.
    if (ref_block_winsorize_fraction is None) == (max_ref_block_base_pairs is None):
        raise ValueError(
            'truncate_reference_blocks: require exactly one of "max_ref_block_base_pairs", "ref_block_winsorize_fraction"'
        )

    # Explicit exceptions rather than `assert`, which is stripped under `python -O`.
    if ref_block_winsorize_fraction is not None:
        if not 0 < ref_block_winsorize_fraction < 1:
            raise ValueError(
                'truncate_reference_blocks: "ref_block_winsorize_fraction" must be between 0 and 1 (e.g. 0.01 to truncate the top 1% of reference blocks)'
            )
        if ref_block_winsorize_fraction > 0.1:
            warning(
                f"'truncate_reference_blocks': ref_block_winsorize_fraction of {ref_block_winsorize_fraction} will lead to significant data duplication,"
                f" recommended values are <0.05."
            )
        max_ref_block_base_pairs = rd.aggregate_entries(
            hl.agg.approx_quantiles(rd.LEN, 1 - ref_block_winsorize_fraction, k=200)
        )

    if not max_ref_block_base_pairs > 0:
        raise ValueError('truncate_reference_blocks: "max_ref_block_base_pairs" must be greater than zero')
    info(f"splitting VDS reference blocks at {max_ref_block_base_pairs} base pairs")

    # Blocks already within the limit stay in place (as a localized array per row).
    rd_under_limit = rd.filter_entries(rd.LEN <= max_ref_block_base_pairs).localize_entries('fixed_blocks', 'cols')

    # Over-limit blocks are flattened to an entries table, tagged with their column index,
    # and exploded into chunks starting every `max_ref_block_base_pairs` bases.
    rd_over_limit = rd.filter_entries(rd.LEN > max_ref_block_base_pairs).key_cols_by(col_idx=hl.scan.count())
    rd_over_limit = rd_over_limit.select_rows().select_cols().key_rows_by().key_cols_by()
    es = rd_over_limit.entries()
    es = es.annotate(new_start=hl.range(es.locus.position, es.locus.position + es.LEN, max_ref_block_base_pairs))
    es = es.explode('new_start')
    es = es.transmute(
        locus=hl.locus(es.locus.contig, es.new_start, reference_genome=es.locus.dtype.reference_genome),
        LEN=hl.min(
            es.locus.position + es.LEN - es.new_start,
            max_ref_block_base_pairs,
        ),
    )
    # we've changed LEN so we need to make sure that END is correct
    if 'END' in es.row:
        es = es.annotate(END=es.LEN + es.locus.position - 1)
    # Per new start locus: map column index -> chunked block.
    es = es.key_by(es.locus).collect_by_key("new_blocks")
    es = es.transmute(moved_blocks_dict=hl.dict(es.new_blocks.map(lambda x: (x.col_idx, x.drop('col_idx')))))

    # Outer join so loci that only exist as new chunk starts get rows; a moved chunk
    # takes precedence over the original block for the same column.
    joined = rd_under_limit.join(es, how='outer')
    joined = joined.transmute(
        merged_blocks=hl.range(hl.len(joined.cols)).map(
            lambda idx: hl.coalesce(joined.moved_blocks_dict.get(idx), joined.fixed_blocks[idx])
        )
    )
    new_rd = joined._unlocalize_entries(
        entries_field_name='merged_blocks', cols_field_name='cols', col_key=list(rd.col_key)
    )
    new_rd = new_rd.annotate_globals(**{fd_name: max_ref_block_base_pairs})

    if isinstance(ds, hl.vds.VariantDataset):
        return VariantDataset(reference_data=new_rd, variant_data=ds.variant_data)
    return new_rd


@typecheck(
    ds=oneof(MatrixTable, VariantDataset),
    equivalence_function=func_spec(2, expr_bool),
    merge_functions=nullable(dictof(str, oneof(str, func_spec(2, expr_any)))),
)
def merge_reference_blocks(ds, equivalence_function, merge_functions=None):
    """Merge adjacent reference blocks according to user equivalence criteria.

    Examples
    --------
    Coarsen GQ granularity into bins of 10 and merge blocks with the same GQ in order to
    compress reference data.

    >>> rd = vds.reference_data # doctest: +SKIP
    >>> vds.reference_data = rd.annotate_entries(GQ = rd.GQ - rd.GQ % 10) # doctest: +SKIP
    >>> vds2 = hl.vds.merge_reference_blocks(vds,
    ...                                      equivalence_function=lambda block1, block2: block1.GQ == block2.GQ,
    ...                                      merge_functions={'MIN_DP': 'min'}) # doctest: +SKIP

    Notes
    -----
    The `equivalence_function` argument expects a function from two reference blocks to a
    boolean value indicating whether they should be combined. Adjacency checks are builtin
    to the method (two reference blocks are 'adjacent' if the END of one block is one base
    before the beginning of the next).

    The `merge_functions` argument maps entry field names to how that field is combined
    when two blocks are merged. Each value is either one of the strings ``'min'``,
    ``'max'`` or ``'sum'`` (case-insensitive), or a function of two blocks (the earlier
    and the later block) returning the merged value. The merged value must have the
    same type as the original field. The merged block takes ``END`` from the later
    block; every other field not listed in `merge_functions` keeps the earlier block's
    value.

    The stored reference block max length global field (if any) is dropped, because
    merged blocks may be longer than that maximum.

    Parameters
    ----------
    ds : :class:`.VariantDataset` or :class:`.MatrixTable`
        Variant dataset or reference block matrix table.
    equivalence_function : function ( (StructExpression, StructExpression) -> :class:`.BooleanExpression`)
        Whether two adjacent reference blocks should be merged.
    merge_functions : dict of str to (str or function), optional
        How to combine individual entry fields of merged blocks (see Notes).

    Returns
    -------
    :class:`.VariantDataset` or :class:`.MatrixTable`

    Raises
    ------
    ValueError
        If a string merge function is not one of 'min', 'max', 'sum', or if a merge
        function produces a value whose type differs from the original field's type.
    """
    if isinstance(ds, VariantDataset):
        rd = ds.reference_data
    else:
        rd = ds
    # Copy each block's start (contig index, position) into the entry so blocks can be
    # compared for adjacency after they are moved away from their row.
    rd = rd.annotate_rows(contig_idx_row=rd.locus.contig_idx, start_pos_row=rd.locus.position)
    rd = rd.annotate_entries(contig_idx=rd.contig_idx_row, start_pos=rd.start_pos_row)
    ht = rd.localize_entries('entries', 'cols')

    # Builtin field combiners, each taking (field name, earlier block, later block).
    builtin_merge_functions = {
        'min': lambda field, b1, b2: hl.min(b1[field], b2[field]),
        'max': lambda field, b1, b2: hl.max(b1[field], b2[field]),
        'sum': lambda field, b1, b2: b1[field] + b2[field],
    }

    def merge(block1, block2):
        """Combine two adjacent, equivalent blocks into one spanning both."""
        new_fields = {'END': block2.END}
        if merge_functions:
            for k, f in merge_functions.items():
                if isinstance(f, str):
                    _f = f.lower()
                    builtin = builtin_merge_functions.get(_f)
                    if builtin is None:
                        raise ValueError(
                            f"merge_reference_blocks: unknown merge function {_f!r},"
                            f" support 'min', 'max', and 'sum' in addition to custom lambdas"
                        )
                    new_value = builtin(k, block1, block2)
                else:
                    # User-supplied combiner, called with both blocks.
                    new_value = f(block1, block2)
                if new_value.dtype != block1[k].dtype:
                    raise ValueError(
                        f'merge_reference_blocks: merge_function for {k!r}: new type {new_value.dtype!r} '
                        f'differs from original type {block1[k].dtype!r}'
                    )
                new_fields[k] = new_value
        return block1.annotate(**new_fields)

    def keep_last(t1, t2):
        # Each tuple is (block, was_merged). Merge the blocks if they are adjacent and
        # equivalent, otherwise keep the later tuple unchanged.
        e1 = t1[0]
        e2 = t2[0]
        are_adjacent = (e1.contig_idx == e2.contig_idx) & (e1.END + 1 == e2.start_pos)
        return hl.if_else(
            hl.is_defined(e1) & hl.is_defined(e2) & are_adjacent & equivalence_function(e1, e2),
            (merge(e1, e2), True),
            t2,
        )

    # approximate a scan that merges before result
    ht = ht.annotate(
        prev_block=hl.zip(
            hl.scan.array_agg(
                lambda elt: hl.scan.fold(
                    (hl.missing(rd.entry.dtype), False), lambda acc: keep_last(acc, (elt, False)), keep_last
                ),
                ht.entries,
            ),
            ht.entries,
        ).map(lambda tup: keep_last(tup[0], (tup[1], False)))
    )
    ht_join = ht

    # Collect the merged blocks and re-key them by the start locus of the merged block.
    ht = ht.key_by()
    ht = ht.select(
        to_shuffle=hl.enumerate(ht.prev_block).filter(
            lambda idx_and_elt: hl.is_defined(idx_and_elt[1]) & idx_and_elt[1][1]
        )
    )
    ht = ht.explode('to_shuffle')
    rg = rd.locus.dtype.reference_genome
    ht = ht.transmute(col_idx=ht.to_shuffle[0], entry=ht.to_shuffle[1][0])
    ht_shuf = ht.key_by(
        locus=hl.locus(hl.literal(rg.contigs)[ht.entry.contig_idx], ht.entry.start_pos, reference_genome=rg)
    )

    ht_shuf = ht_shuf.collect_by_key("new_starts")
    # new_starts can contain multiple records for a collapsed ref block, one for each folded block.
    # We want to keep the one with the highest END
    ht_shuf = ht_shuf.select(
        moved_blocks_dict=hl.group_by(lambda elt: elt.col_idx, ht_shuf.new_starts).map_values(
            lambda arr: arr[hl.argmax(arr.map(lambda x: x.entry.END))].entry.drop('contig_idx', 'start_pos')
        )
    )

    ht_joined = ht_join.join(ht_shuf.select_globals(), 'left')

    def merge_f(tup):
        # Blocks that were folded into an earlier block become missing; otherwise prefer
        # the merged block starting here, falling back to the original entry.
        (idx, original_entry) = tup

        return (
            hl.case()
            .when(
                ~(hl.coalesce(ht_joined.prev_block[idx][1], False)),
                hl.coalesce(ht_joined.moved_blocks_dict.get(idx), original_entry.drop('contig_idx', 'start_pos')),
            )
            .or_missing()
        )

    ht_joined = ht_joined.annotate(new_entries=hl.enumerate(ht_joined.entries).map(lambda tup: merge_f(tup)))
    ht_joined = ht_joined.drop('moved_blocks_dict', 'entries', 'prev_block', 'contig_idx_row', 'start_pos_row')
    new_rd = ht_joined._unlocalize_entries(
        entries_field_name='new_entries', cols_field_name='cols', col_key=list(rd.col_key)
    )

    # Merged blocks can exceed any previously stored maximum block length.
    rbml = hl.vds.VariantDataset.ref_block_max_length_field
    if rbml in new_rd.globals:
        new_rd = new_rd.drop(rbml)

    if isinstance(ds, VariantDataset):
        return VariantDataset(reference_data=new_rd, variant_data=ds.variant_data)
    return new_rd
